#pragma once
#include <array>
#include <sys/cdefs.h>
#include <tuple>
#include <cassert>
#include "cell.h"
#include "esmd_types.h"
#include "treduce.hpp"
#include "utils.hpp"
#include "unitconv.hpp"
#include "list_cpp.hpp"
#include "htab_cpp.hpp"
#include "dimension.h"

struct mpp_t;
struct cellgrid_t;
struct celldata_t;
/* DEFTIE(func, fields...) declares two member functions named `func`:
 *   - the non-const overload returns std::tie(fields...), a tuple of lvalue
 *     references to the listed members (assigning through it writes them);
 *   - the const overload returns std::make_tuple(fields...), a tuple of copies.
 * Tuple element order is the order the fields are listed here, which may
 * differ from the struct's declaration order.
 * NOTE(review): invoking DEFTIE(func) with no fields (as special_pair does)
 * leaves __VA_ARGS__ empty; strict C++17 does not allow that (GCC/Clang
 * accept it as an extension, and C++20 permits it).
 */
#define DEFTIE(func, ...)                       \
  auto func() { return std::tie(__VA_ARGS__); } \
  auto func() const { return std::make_tuple(__VA_ARGS__); }

/* Fixed-size view over N externally owned vec objects.
 * The constructor stores the address of each argument in order, and
 * operator[] dereferences the i-th stored pointer. No bounds checking is done.
 * The referenced objects must outlive this view.
 */
template <typename T, int N>
struct vecref {
  T *refs[N];

  template <typename... Ts>
  vecref(vec<Ts> &...args) : refs{&args...} {}

  T &operator[](int i) const {
    T *const target = refs[i];
    return *target;
  }
};

/* Type trait that maps a built-in array type T[N] to the pointer type T*.
 * Every other type maps to itself. array2ptr<T> is the shorthand alias.
 */
template <typename T>
struct array2ptr_helper {
  using type = T;
};

template <typename T, int N>
struct array2ptr_helper<T[N]> {
  using type = T *;
};

template <typename T>
using array2ptr = typename array2ptr_helper<T>::type;
/* N force-like accumulators, all starting at zero.
 * The constructor accepts vec references (matching xref/vecref's signature)
 * but does not read or store them. It only zeroes the N local values.
 * The destructor is intentionally empty and writes nothing back.
 */
template <int N>
struct fref {
  vec<real> val[N];

  template <typename... Ts>
  fref(vec<Ts> &... /*args: intentionally unused*/) {
    for (vec<real> &component : val) {
      component = 0;
    }
  }

  vec<real> &operator[](int i) { return val[i]; }

  ~fref() {}
};
/* Read-only snapshot of N position-like vectors.
 * The constructor copies each argument (converted to vec<real>) into val,
 * in order. If fewer than N arguments are given, the remaining slots are
 * value-initialized.
 * operator[] is const-qualified so a const xref can be indexed too; it only
 * ever hands out const references to the stored copies.
 */
template <int N>
struct xref {
  const vec<real> val[N];

  template <typename... Ts>
  xref(vec<Ts> &...args) : val{args...} {}

  const vec<real> &operator[](int i) const {
    return val[i];
  }
};

/* LJ/Coulomb parameters for 1-4 pairs routed through a central atom, which
 * reduces the number of ghost atoms needed.
 * Entries are 3-body (nbodies = 3). owner = 1 is the position in each
 * entry's id[] that listed_force_entry::operator< compares, i.e. the middle
 * atom. Per-cell capacity is MAX_TORI_CALC_CELL.
 * Presumably sigma/epsilon are the Lennard-Jones parameters and qq the charge
 * product; confirm against the force code.
 * NOTE(review): tie() exposes only sigma and epsilon, not qq. Confirm the
 * omission is intentional wherever tie() is used (comparison, I/O, ...).
 */
struct routed_ljcoul_param {
  static constexpr int nbodies = 3;
  static constexpr int owner = 1;
  static constexpr int cell_cap = MAX_TORI_CALC_CELL;
  real sigma, epsilon, qq;
  DEFTIE(tie, sigma, epsilon)
};

/* LJ/Coulomb pair parameters (2-body, owner = position 0 of each entry).
 * Per-cell capacity is MAX_TORI_CALC_CELL.
 * Presumably sigma/epsilon are the Lennard-Jones parameters and qq the charge
 * product; confirm against the force code.
 * NOTE(review): tie() exposes only sigma and epsilon, not qq. Confirm the
 * omission is intentional wherever tie() is used (comparison, I/O, ...).
 */
struct ljcoul_param {
  static constexpr int nbodies = 2;
  static constexpr int owner = 0;
  static constexpr int cell_cap = MAX_TORI_CALC_CELL;
  real sigma, epsilon, qq;
  DEFTIE(tie, sigma, epsilon)
};

/* Harmonic bond parameters (2-body, owner = position 0 of each entry).
 * Per-cell capacity is MAX_BONDED_CELL.
 * std_len() returns r0, presumably the equilibrium length; kr is presumably
 * the force constant.
 * units_gmx has one entry per tie() field, in the same order (r0, kr). It is
 * presumably used to convert GROMACS-topology values; see unitconv.hpp for
 * the meaning of units::gmx_K's arguments.
 * calc() is only declared here. Presumably it accumulates forces into f from
 * positions x and returns the energy; confirm at its definition.
 */
struct harmonic_bond_param {
  static constexpr int nbodies = 2;
  static constexpr int owner = 0;
  static constexpr int cell_cap = MAX_BONDED_CELL;
  real r0, kr;
  DEFTIE(tie, r0, kr);
  __always_inline real std_len(){return r0;};
  static constexpr auto units_gmx = std::make_tuple(units::nm, units::gmx_K(0.5, -2, 0));
  __always_inline real calc(fref<nbodies> &f, xref<nbodies> &x);
};
/* Harmonic angle parameters (3-body, owner = position 1, the vertex atom
 * presumably). Per-cell capacity is MAX_ANGLE_CALC_CELL.
 * theta0 is read in degrees per units_gmx; ktheta is presumably the force
 * constant.
 * units_gmx has one entry per tie() field, in the same order.
 * calc() is only declared here. Presumably it accumulates forces into f and
 * returns the energy; confirm at its definition.
 */
struct harmonic_angle_param {
  static constexpr int nbodies = 3;
  static constexpr int owner = 1;
  static constexpr int cell_cap = MAX_ANGLE_CALC_CELL;
  real theta0, ktheta;
  DEFTIE(tie, theta0, ktheta);
  static constexpr auto units_gmx = std::make_tuple(units::deg, units::gmx_K(0.5, 0, -2));
  __always_inline real calc(fref<nbodies> &f, xref<nbodies> &x);
};
/* Urey-Bradley angle parameters (3-body, owner = position 1).
 * Per-cell capacity is MAX_ANGLE_CALC_CELL.
 * Presumably theta0/ktheta form the angle term and r0ub/kub the 1-3 distance
 * term; the units_gmx entries (deg, K, nm, K) are consistent with that.
 * units_gmx has one entry per tie() field, in the same order.
 * calc() is only declared here; confirm its contract at the definition.
 */
struct urey_bradly_param {
  static constexpr int nbodies = 3;
  static constexpr int owner = 1;
  static constexpr int cell_cap = MAX_ANGLE_CALC_CELL;
  real theta0, ktheta, r0ub, kub;
  DEFTIE(tie, theta0, ktheta, r0ub, kub)
  static constexpr auto units_gmx = std::make_tuple(units::deg, units::gmx_K(0.5, 0, -2), units::nm, units::gmx_K(0.5, -2, 0));
  __always_inline real calc(fref<nbodies> &f, xref<nbodies> &x);
};
/* Cosine (periodic) torsion parameters (4-body, owner = position 1).
 * Per-cell capacity is MAX_TORI_CALC_CELL.
 * phi0 is read in degrees per units_gmx. kphi is presumably the barrier
 * constant and np (an int, dimensionless per units::one) the multiplicity.
 * units_gmx has one entry per tie() field, in the same order.
 * calc() is only declared here; confirm its contract at the definition.
 */
struct cosine_torsion_param {
  static constexpr int nbodies = 4;
  static constexpr int owner = 1;
  static constexpr int cell_cap = MAX_TORI_CALC_CELL;
  real phi0, kphi;
  int np;
  DEFTIE(tie, phi0, kphi, np)
  static constexpr auto units_gmx = std::make_tuple(units::deg, units::gmx_K(1, 0, 0), units::one);
  __always_inline real calc(fref<nbodies> &f, xref<nbodies> &x);
};
/* Ryckaert-Bellemans torsion parameters (4-body, owner = position 1).
 * Per-cell capacity is MAX_TORI_CALC_CELL.
 * c0..c5 are presumably the polynomial coefficients. All six share the same
 * unit (units::gmx_K(1, 0, 0)).
 * units_gmx has one entry per tie() field, in the same order.
 * calc() is only declared here; confirm its contract at the definition.
 */
struct ryckaert_bellmans_param {
  static constexpr int nbodies = 4;
  static constexpr int owner = 1;
  static constexpr int cell_cap = MAX_TORI_CALC_CELL;
  real c0, c1, c2, c3, c4, c5;
  DEFTIE(tie, c0, c1, c2, c3, c4, c5)
  static constexpr auto units_gmx = std::make_tuple(units::gmx_K(1, 0, 0), units::gmx_K(1, 0, 0), units::gmx_K(1, 0, 0), units::gmx_K(1, 0, 0), units::gmx_K(1, 0, 0), units::gmx_K(1, 0, 0));
  __always_inline real calc(fref<nbodies> &f, xref<nbodies> &x);
};
/* Ryckaert-Bellemans-style torsion with seven coefficients c0..c6
 * (4-body, owner = position 1). Per-cell capacity is MAX_TORI_CALC_CELL.
 * All coefficients share the unit units::gmx_K(1, 0, 0); units_gmx has one
 * entry per tie() field, in the same order.
 * operator+= takes a cosine_torsion_param. Presumably it folds that periodic
 * term into these coefficients; confirm at its definition.
 * calc() is only declared here; confirm its contract at the definition.
 */
struct extended_ryckaert_bellmans_param {
  static constexpr int nbodies = 4;
  static constexpr int owner = 1;
  static constexpr int cell_cap = MAX_TORI_CALC_CELL;
  real c0, c1, c2, c3, c4, c5, c6;
  DEFTIE(tie, c0, c1, c2, c3, c4, c5, c6)
  static constexpr auto units_gmx = std::make_tuple(units::gmx_K(1, 0, 0), units::gmx_K(1, 0, 0), units::gmx_K(1, 0, 0), units::gmx_K(1, 0, 0), units::gmx_K(1, 0, 0), units::gmx_K(1, 0, 0), units::gmx_K(1, 0, 0));
  __always_inline extended_ryckaert_bellmans_param &operator+=(const cosine_torsion_param &p);
  __always_inline real calc(fref<nbodies> &f, xref<nbodies> &x);
};
/* Harmonic improper dihedral parameters (4-body). Unlike the proper torsions
 * above, owner = position 0 of each entry. Per-cell capacity is
 * MAX_IMPR_CALC_CELL.
 * psi0 is read in degrees per units_gmx; kpsi is presumably the force
 * constant. units_gmx has one entry per tie() field, in the same order.
 * calc() is only declared here; confirm its contract at the definition.
 */
struct harmonic_improper_param {
  static constexpr int nbodies = 4;
  static constexpr int owner = 0;
  static constexpr int cell_cap = MAX_IMPR_CALC_CELL;
  real psi0, kpsi;
  DEFTIE(tie, psi0, kpsi);
  static constexpr auto units_gmx = std::make_tuple(units::deg, units::gmx_K(0.5, 0, -2));
  __always_inline real calc(fref<nbodies> &f, xref<nbodies> &x);
};
/* Rigid-group parameters (up to 4 bodies, owner = position 0). Per-cell
 * capacity is CELL_CAP.
 * Presumably r0[0..2] are the reference distances and `type` selects the
 * constraint variant; confirm against the constraint code.
 * Note: tie() orders the fields as (r0[0], r0[1], r0[2], type), which is not
 * the declaration order.
 */
struct rigid_param{
  static constexpr int nbodies = 4;
  static constexpr int owner = 0;
  static constexpr int cell_cap = CELL_CAP;
  int type;
  real r0[3];
  DEFTIE(tie, r0[0], r0[1], r0[2], type);
};
/* Special pair entries (2-body, owner = position 0). They carry no numeric
 * parameters, so tie() returns an empty tuple.
 * EXCL/SCAL presumably tag excluded vs. scaled pairs; confirm where these
 * values are consumed.
 * NOTE(review): cell_cap is a bare literal (1360), unlike the MAX_* / CELL_CAP
 * macros used by the other parameter types. DEFTIE(tie) with no fields relies
 * on empty __VA_ARGS__ (C++20 / compiler extension).
 */
struct special_pair {
  static constexpr int nbodies = 2;
  static constexpr int owner = 0;
  static constexpr int cell_cap = 1360;
  enum {
    EXCL = 0,
    SCAL = 1
  };
  DEFTIE(tie);
};
/* One listed interaction: the ids of its ParamType::nbodies atoms plus `pid`,
 * an int that presumably indexes the parameter table of this type.
 * Entries compare by the id stored at position ParamType::owner.
 */
template <class ParamType, typename IdType>
struct listed_force_entry {
  IdType id[ParamType::nbodies];
  int pid;
  static constexpr auto iseq = utils::make_iseq<int, ParamType::nbodies>{};

  // Builds an entry whose ids are std::get<Is>(tpl)..., in pack order.
  template <typename Tk, int... Is>
  static listed_force_entry from_tuple(Tk &tpl, int pid, utils::iseq<int, Is...> seq) {
    return listed_force_entry{{std::get<Is>(tpl)...}, pid};
  }

  // Same as above, using the class-wide index sequence `iseq`.
  template <typename Tk>
  static listed_force_entry from_tuple(Tk &tpl, int pid) {
    return from_tuple(tpl, pid, iseq);
  }

  // Offsets every atom id in this entry by `value`.
  void shift(IdType value) {
    for (IdType &atom : id) {
      atom += value;
    }
  }

  // Orders entries by their owner atom's id only.
  bool operator<(const listed_force_entry<ParamType, IdType> &other) const {
    const IdType mine = id[ParamType::owner];
    const IdType theirs = other.id[ParamType::owner];
    return mine < theirs;
  }
};

#include "buffer.hpp"
template <typename ParamType>
using listed_force_tag_entry = listed_force_entry<ParamType, tagint>;

/* Global (tag-indexed) topology for one listed-force kind.
 * `topo` holds the interaction entries keyed by tagint atom tags. `param` and
 * `nparam` describe the parameter table those entries refer to (presumably
 * via pid); ownership of `param` is not established here.
 * param/nparam start as nullptr/0 so a freshly constructed object never holds
 * indeterminate values. The original constructor left them uninitialized.
 */
template <typename ParamType>
struct listed_force {
  ParamType *param = nullptr;
  int nparam = 0;
  esmd::list<listed_force_tag_entry<ParamType>> topo;
  listed_force() : topo("Anon listed force") {}
};
/* Fixed-capacity storage for the listed-force entries held by one cell.
 * entries[0, cnt) are the valid entries; begin()/end() iterate exactly that
 * range. nexport is a second counter used by the export/import routines
 * (presumably the number of entries staged for export; confirm at their
 * definitions). IdType is the atom-identifier type of each entry: int or
 * tagint as used by listed_force_cell.
 * The member functions without bodies are only declared here.
 */
template <typename ParamType, typename IdType, int Capacity>
struct listed_force_cell_ents {
  size_t cnt, nexport;
  listed_force_entry<ParamType, IdType> entries[Capacity];
  static constexpr int cap = Capacity;
  static constexpr auto body_seq = utils::make_iseq<int, ParamType::nbodies>{};
  template <typename TagMapType>
  void export_entries(const TagMapType &tag_map);
  void pack_exported(pack_buffer &buffer);
  void unpack_exported(unpack_buffer &buffer);
  template <typename TagMapType>
  void import_entries(const TagMapType &tag_map, listed_force_cell_ents<ParamType, IdType, Capacity> &from);
  template<typename GuestMapType, typename LocalMapType>
  void find_guests(GuestMapType &guests, LocalMapType &locals);
  template<int ...Is>
  real calc(vec<real> *f, vec<real> *x, ParamType *params, utils::iseq<int, Is...> seq);
  real calc(vec<real> *f, vec<real> *x, ParamType *params);

  listed_force_entry<ParamType, IdType> *begin() {
    return entries;
  }
  listed_force_entry<ParamType, IdType> *end() {
    return entries + cnt;
  }

  /* Byte count from the start of this object to the end of entries[cnt - 1]:
   * the counters plus the used prefix of `entries`, without the unused tail.
   * Uses `entries + cnt` rather than `&entries[cnt]`, because the latter
   * dereferences the one-past-the-end element when cnt == Capacity. Byte
   * pointers replace the original C-style casts to size_t.
   * Const-qualified so it can be queried on const objects.
   */
  size_t effective_size() const {
    const char *const base = reinterpret_cast<const char *>(this);
    const char *const used_end = reinterpret_cast<const char *>(entries + cnt);
    return static_cast<size_t>(used_end - base);
  }
};
/* Listed-force storage for one cell and one interaction kind, kept twice:
 * by_id stores entries with int atom ids and by_tag stores entries with
 * tagint atom tags. relink(tag_map) presumably rebuilds by_id from by_tag
 * using tag_map; confirm at its definition, since these members are only
 * declared here.
 */
template <typename ParamType, int Capacity>
struct listed_force_cell {
  static constexpr int cap = Capacity;
  static constexpr auto body_seq = utils::make_iseq<int, ParamType::nbodies>{};

  // Data members: declaration order fixes the layout (by_id before by_tag).
  listed_force_cell_ents<ParamType, int, Capacity> by_id;
  listed_force_cell_ents<ParamType, tagint, Capacity> by_tag;

  template <typename TagMapType>
  void relink(const TagMapType &tag_map);
  template <typename TagMapType, int... Is>
  void relink(const TagMapType &tag_map, utils::iseq<int, Is...> seq);

  real calc(vec<real> *f, vec<real> *x, ParamType *params);
};

/* Grid-level storage for one listed-force kind.
 * `cells` presumably points to one listed_force_cell per cell of the
 * cellgrid_t passed to allocate(). param/nparam presumably mirror the
 * parameter table of the listed_force passed to distribute(). Both functions
 * are only declared here.
 * distribute() receives a tag -> vec<int> map, presumably tag to cell
 * coordinates; confirm at the definition.
 * NOTE(review): members have no initializers and are meaningful only after
 * allocate()/distribute(). This type is stored inside utils::variant in
 * topology_grids; check that variant's requirements (e.g. trivial default
 * construction) before adding default member initializers.
 */
template <typename ParamType, int Capacity>
struct listed_force_grid {
  ParamType *param;
  int nparam;
  listed_force_cell<ParamType, Capacity> *cells;

  void allocate(cellgrid_t &grid);
  void distribute(esmd::htab<tag_key, vec<int>> &tag_cell_map, listed_force<ParamType> &listed, cellgrid_t &grid);
};
// Cell/grid types sized by the parameter type's own per-cell capacity
// (ParamType::cell_cap).
template <typename ParamType> using topology_cell = listed_force_cell<ParamType, ParamType::cell_cap>;
template <typename ParamType> using topology_grid = listed_force_grid<ParamType, ParamType::cell_cap>;
/* Index of each topology kind inside topology_grids::grids.
 * The numbering must match the alternative order of
 * topology_grids::grid_variants, because topology_grids::get<ITH>() uses the
 * same integer both as the array index and as the variant alternative index.
 * TOP_END is the number of kinds.
 */
enum topo_types {
  TOP_HARMONIC_BOND = 0,
  TOP_HARMONIC_ANGLE = 1,
  TOP_UREY_BRADLY = 2,
  TOP_COSINE_TORSION = 3,
  TOP_RYCKAERT_BELLMANS = 4,
  TOP_EXTENDED_RYCKAERT_BELLMANS = 5,
  TOP_HARMONIC_IMPROPER = 6,
  TOP_RIGID = 7,
  TOP_SPECIAL = 8,
  TOP_END = 9
};

/* Holds one listed-force grid per topology kind and dispatches per-kind work.
 * grids[i] is accessed as alternative i of grid_variants (see get<ITH>()), so
 * the topo_types numbering must match the alternative order below.
 * present[i] records whether kind i has been allocated and distributed. It is
 * set in allocate_and_distribute() and gates the export packer/unpacker.
 */
struct topology_grids {
  // Alternative order must match enum topo_types.
  typedef utils::variant<topology_grid<harmonic_bond_param>,
                         topology_grid<harmonic_angle_param>,
                         topology_grid<urey_bradly_param>,
                         topology_grid<cosine_torsion_param>,
                         topology_grid<ryckaert_bellmans_param>,
                         topology_grid<extended_ryckaert_bellmans_param>,
                         topology_grid<harmonic_improper_param>,
                         topology_grid<rigid_param>,
                         topology_grid<special_pair>
                         > grid_variants;
  // Range delimiters over topo_types. Kinds below TOP_CALC_END presumably get
  // force evaluation, and kinds below TOP_GUEST_END presumably take part in
  // guest lookup. Confirm at the compute()/find_guests() definitions.
  static constexpr int TOP_CALC_END = TOP_RIGID;
  static constexpr int TOP_GUEST_END = TOP_SPECIAL;
  grid_variants grids[TOP_END];
  bool present[TOP_END];

  // Starts with no kind present.
  topology_grids(){
    for (int i = 0; i < TOP_END; i ++) present[i] = 0;
  }
  // Grid for kind ITH (a topo_types value). Only the upper bound is checked.
  template<int ITH>
  auto &get(){
    static_assert(ITH < TOP_END, "No such grid type");
    return grids[ITH].as<ITH>();
  }
  // Grid for parameter type T, located by T's position in grid_variants.
  template<typename T>
  auto &get(){
    static_assert(grid_variants::variants::has<topology_grid<T>>, "No such grid type");
    return grids[utils::get_ord<grid_variants, topology_grid<T>>::value].template as<topology_grid<T>>();
  }

  // For each listed_force<Ts> given: allocate its grid on `grid`, distribute
  // its entries using `tagmap`, then mark the kind present.
  template<typename ...Ts>
  void allocate_and_distribute(cellgrid_t &grid, esmd::htab<tag_key, vec<int>> &tagmap, listed_force<Ts> &...lists){
    utils::expand{
      (get<Ts>().allocate(grid), get<Ts>().distribute(tagmap, lists, grid), present[utils::get_ord<grid_variants, topology_grid<Ts>>::value] = true, 0)...
    };
  }

  // Per-kind operations, declared here and defined elsewhere. Each pair is a
  // public entry point plus an overload taking an index sequence, presumably
  // the pack-expansion worker.
  template<typename TagMapType>
  void do_export(int offset, TagMapType &tagmap);
  template<typename TagMapType, int... Is>
  void do_export(int offset, TagMapType &tagmap, const utils::iseq<int, Is...> &seq);
  template<typename TagMapType>
  void relink(int offset, TagMapType &tagmap);
  template<typename TagMapType, int... Is>
  void relink(int offset, TagMapType &tagmap, const utils::iseq<int, Is...> &seq);

  template<typename TagMapType, int...Is>
  void do_import(int self, int from, TagMapType &tagmap, const utils::iseq<int, Is...> &seq);
  template<typename TagMapType>
  void do_import(int self, int from, TagMapType &tagmap);
  template<int I, typename TagMapType, typename GuestMapType>
  void do_import_neighbors(int self, TagMapType &tagmap, GuestMapType &guestmap, cellgrid_t *grid, int cx, int cy, int cz);
  template<typename TagMapType, typename GuestMapType, int ...Is>
  void do_import_neighbors(int self, TagMapType &tagmap, GuestMapType &guestmap, cellgrid_t *grid, int cx, int cy, int cz, utils::iseq<int, Is...> bseq);

  template<typename GuestMapType, typename LocalMapType, int ...Is>
  void find_guests(int offset, GuestMapType &guests, LocalMapType &locals, const utils::iseq<int, Is...> &seq);
  template<typename GuestMapType, typename LocalMapType>
  void find_guests(int offset, GuestMapType &guests, LocalMapType &locals);

  // Functor: for cell `offset`, packs by_tag entries of every present kind
  // into `buffer`, iterating kinds over make_iseq<int, TOP_END>.
  struct export_packer{
    export_packer(topology_grids &grid) : grid(grid){
    }
    topology_grids &grid;
    template<int ...Is>
    void pack(pack_buffer &buffer, int offset, const utils::iseq<int, Is...> &seq){
      utils::expand{
        (grid.present[Is] ? (grid.get<Is>().cells[offset].by_tag.pack_exported(buffer), 0) : 0)...
      };
    }
    void operator()(pack_buffer &buffer, int offset){
      static constexpr auto subseq = utils::make_iseq<int, TOP_END>{};
      pack(buffer, offset, subseq);
    }
  };
  // Mirror of export_packer. It walks the same kind sequence with the same
  // present[] gating, so the unpack order matches the pack order (assuming
  // present[] agrees on both sides).
  struct export_unpacker{
    export_unpacker(topology_grids &grid) : grid(grid){
    }
    topology_grids &grid;
    template<int ...Is>
    void unpack(unpack_buffer &buffer, int offset, const utils::iseq<int, Is...> &seq){
      utils::expand{
        (grid.present[Is] ? (grid.get<Is>().cells[offset].by_tag.unpack_exported(buffer), 0) : 0)...
      };
    }
    void operator()(unpack_buffer &buffer, int offset){
      static constexpr auto subseq = utils::make_iseq<int, TOP_END>{};
      unpack(buffer, offset, subseq);
    }
  };

  // Declared here, defined elsewhere; confirm contracts at the definitions.
  void forward_comm_export(mpp_t *mpp, cellgrid_t *grid);
  void update (mpp_t *mpp, cellgrid_t *grid);
  void compute(mdstat_t *stat, cellgrid_t *grid);
};


/* Declared here, defined elsewhere.
 * compute_listed_forces: entry point over a cellgrid_t. `param` is an opaque
 * pointer whose expected type is fixed by the definition (TODO confirm);
 * `stat` is presumably where energies/statistics are accumulated.
 */
void compute_listed_forces (cellgrid_t *grid, void *param, mdstat_t *stat);
/* route_pairs: presumably builds `routed` entries (e.g. routed_ljcoul_param
 * 1-4 pairs attached to a central atom) from `base` pairs, using `router`
 * interactions to pick the central atom. Confirm at the definition.
 */
template<typename RoutedType, typename BaseType, typename RouterType>
void route_pairs(listed_force<RoutedType> &routed, listed_force<BaseType> &base, listed_force<RouterType> &router);
